import argparse
import os
import json
import hashlib
import heapq
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# This script only reads and rewrites weights; turn autograd off globally so no
# computation graphs are recorded.
torch.set_grad_enabled(False)

def _stable_int(s: str) -> int:
    return int.from_bytes(hashlib.sha256(s.encode()).digest()[:4], "little")

def first_tensor(obj):
    """Return the first tensor contained in *obj*, or None.

    A tensor is returned as is. For a dict, the first value (in iteration
    order) that is a tensor is returned, or None if there is none. Any other
    object yields None.
    """
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, dict):
        return next((val for val in obj.values() if isinstance(val, torch.Tensor)), None)
    return None

def load_data(path, device="cpu"):
    """Load ``path`` with ``torch.load`` onto *device* and return its first tensor.

    Returns None when *path* is empty/None or the file does not exist; the
    tensor itself is picked by ``first_tensor``.
    """
    if path and os.path.exists(path):
        payload = torch.load(path, map_location=device)
        return first_tensor(payload)
    return None

def load_diag(hdir: str, layer_idx: int, tag: str, out_dim: int, in_dim: int):
    """Load a length-``out_dim`` Hessian-diagonal vector for one weight block, or None.

    Looks for ``{hdir}/{layer_idx}_{tag}.pt`` first. Only for the attention
    projections ``q``/``k``/``v`` does it fall back to a fused
    ``{hdir}/{layer_idx}_qkv.pt`` (a dict keyed by tag, or a single tensor).
    The loaded tensor is reduced as follows:

    * 2-D: its diagonal, provided it has ``out_dim`` entries;
    * 1-D of length ``out_dim``: returned unchanged;
    * 1-D of length ``out_dim * in_dim``: reshaped to (out_dim, in_dim) and
      summed over the second axis;
    * 1-D of length ``3 * out_dim`` (q/k/v only): the slice belonging to *tag*.

    Returns None when *hdir* is empty, no file is found, or no rule matches.
    """
    if not hdir:
        return None
    qkv_slot = {"q": 0, "k": 1, "v": 2}
    own_path = f"{hdir}/{layer_idx}_{tag}.pt"
    if os.path.exists(own_path):
        H = first_tensor(torch.load(own_path, map_location="cpu"))
    elif tag in qkv_slot:
        qkv_path = f"{hdir}/{layer_idx}_qkv.pt"
        if not os.path.exists(qkv_path):
            return None
        obj = torch.load(qkv_path, map_location="cpu")
        if isinstance(obj, dict) and tag in obj:
            entry = obj[tag]
            H = entry if isinstance(entry, torch.Tensor) else first_tensor(entry)
        else:
            H = first_tensor(obj)
    else:
        # A fused qkv file carries no statistics for o/up/down/gate blocks;
        # falling back to it would silently reuse another projection's data.
        return None
    if H is None:
        return None
    if H.ndim == 2:
        d = H.diag()
        # A diagonal of the wrong length cannot be combined with U (out_dim rows).
        return d.contiguous() if d.numel() == out_dim else None
    if H.ndim == 1:
        n = H.numel()
        if n == out_dim:
            return H.contiguous()
        if in_dim and n == out_dim * in_dim:
            return H.reshape(out_dim, in_dim).sum(dim=1).contiguous()
        if n == 3 * out_dim and tag in qkv_slot:
            return H.reshape(3, out_dim)[qkv_slot[tag]].contiguous()
    return None

def _minmax_norm(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    x = x.float()
    maxv = torch.amax(x)
    minv = torch.amin(x)
    rng = (maxv - minv).clamp_min(eps)
    return (x - minv) / rng

def parse_kv_map_int(s: str) -> Dict[str, int]:
    """Parse a ``"key:int,key:int"`` command-line map into a dict.

    Whitespace around keys and values is ignored and empty entries (for
    example from a trailing comma) are skipped; later duplicates win.
    An empty or None *s* gives ``{}``.

    Raises:
        ValueError: for an entry without ``:``, with an empty key, or whose
            value is not an integer.
    """
    result: Dict[str, int] = {}
    if not s:
        return result
    for raw in s.split(","):
        entry = raw.strip()
        if not entry:
            continue
        key, sep, value = entry.partition(":")
        key = key.strip()
        if not sep or not key:
            raise ValueError(f"malformed map entry {entry!r} in {s!r}; expected 'key:int'")
        result[key] = int(value)
    return result

def parse_blocks(s: str) -> List[str]:
    """Translate the ``--blocks`` option into a list of block tags.

    ``"all"`` (any case) selects every per-layer projection; ``"all+io"`` also
    adds ``embed`` and ``lm_head``. Otherwise *s* is a comma-separated list
    whose stripped names are kept, in order, only if they are known tags.
    """
    known = ("q", "k", "v", "o", "up", "down", "gate", "embed", "lm_head")
    io_tags = ("embed", "lm_head")
    mode = s.lower()
    if mode == "all":
        return [tag for tag in known if tag not in io_tags]
    if mode == "all+io":
        return list(known)
    requested = (token.strip() for token in s.split(","))
    return [token for token in requested if token in known]

def _make_block_ortho(n, block, gen, device):
    mats = []
    i = 0
    while i < n:
        b = min(block, n - i)
        A = torch.randn(b, b, generator=gen, device=device, dtype=torch.float32)
        Q, _ = torch.linalg.qr(A, mode='reduced')
        mats.append(Q)
        i += b
    return mats

@torch.no_grad()
def apply_precond(W, row_block, col_block, seed):
    device = W.device
    gen = torch.Generator(device=device)
    gen.manual_seed(int(seed))
    out_dim, in_dim = W.shape
    d_row = torch.randint(0, 2, (out_dim,), generator=gen, device=device, dtype=torch.int8).float()*2-1
    d_col = torch.randint(0, 2, (in_dim,),  generator=gen, device=device, dtype=torch.int8).float()*2-1
    ORs = _make_block_ortho(out_dim, row_block, gen, device)
    OCs = _make_block_ortho(in_dim,  col_block, gen, device)
    W = (d_row.view(-1,1) * W)
    r0 = 0
    for OR in ORs:
        r1 = r0 + OR.shape[0]
        W[r0:r1, :] = OR @ W[r0:r1, :]
        r0 = r1
    W = W * d_col.view(1,-1)
    c0 = 0
    for OC in OCs:
        c1 = c0 + OC.shape[0]
        W[:, c0:c1] = W[:, c0:c1] @ OC
        c0 = c1
    @torch.no_grad()
    def inv(Wp):
        c0 = 0
        for OC in OCs:
            c1 = c0 + OC.shape[0]
            Wp[:, c0:c1] = Wp[:, c0:c1] @ OC.T
            c0 = c1
        Wp = Wp * d_col.view(1,-1)
        r0 = 0
        for OR in ORs:
            r1 = r0 + OR.shape[0]
            Wp[r0:r1, :] = OR.T @ Wp[r0:r1, :]
            r0 = r1
        Wp = d_row.view(-1,1) * Wp
        return Wp
    return W, inv

@torch.no_grad()
def quantize_per_row_symmetric(W: torch.Tensor, bits: int, q2_mode: str = "zero", q2_lsq_iters: int = 0, jetson_static_scale: bool = False):
    """Quantize a 2-D weight to integer codes with one scale per row.

    Args:
        W: 2-D tensor; all computation happens on ``W.float()``.
        bits: code width in 2..16. ``bits == 2`` uses a 4-level codebook picked
            by *q2_mode*. Wider widths use the symmetric integer range
            ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``: 7 for 4-bit, 127 for 8-bit.
        q2_mode: 2-bit codebook: ``"zero"`` = {-2,-1,0,1}, ``"sym"`` = {-2,-1,1,2},
            ``"pm3"`` = {-3,-1,1,3}.
        q2_lsq_iters: 2-bit only. Number of refinement rounds; each round
            re-assigns the nearest codes, then sets every row's scale to its
            least-squares optimum (clamped to [1e-6, 1e3]).
        jetson_static_scale: widths > 2 only. Use one tensor-wide scale,
            repeated for every row.

    Returns:
        (q, scale): codes shaped like *W* (int8 for bits <= 8, int16 above),
        and a contiguous float32 per-row scale vector, so that
        ``W ~= q * scale[:, None]``.

    Raises:
        ValueError: if *bits* is outside 2..16.
        KeyError: for an unknown *q2_mode* when ``bits == 2``.
    """
    if not 2 <= bits <= 16:
        raise ValueError(f"unsupported bit width {bits}; expected 2..16")
    Wf = W.float()
    if bits == 2:
        levels, qmax = {
            "sym": ((-2, -1, 1, 2), 2.0),
            "pm3": ((-3, -1, 1, 3), 3.0),
            "zero": ((-2, -1, 0, 1), 2.0),
        }[q2_mode]
        code = torch.tensor(levels, device=Wf.device, dtype=Wf.dtype)
        s = torch.amax(Wf.abs(), dim=1, keepdim=True).clamp_min(1e-12) / qmax
        for _ in range(max(0, int(q2_lsq_iters))):
            ratio = (Wf / s).unsqueeze(-1)
            idx = torch.argmin((ratio - code) ** 2, dim=-1)
            qf = code[idx]
            # Closed-form least-squares scale for fixed codes: <W, q> / <q, q>.
            denom = (qf * qf).sum(dim=1, keepdim=True).clamp_min(1e-12)
            s = ((Wf * qf).sum(dim=1, keepdim=True) / denom).clamp_(min=1e-6, max=1e3)
        ratio = (Wf / s).unsqueeze(-1)
        idx = torch.argmin((ratio - code) ** 2, dim=-1)
        q = code[idx].to(torch.int8)
        return q, s.squeeze(1).contiguous()
    qmax = float(2 ** (bits - 1) - 1)
    # Codes wider than 8 bits do not fit in int8 without wrapping.
    qdtype = torch.int8 if bits <= 8 else torch.int16
    if jetson_static_scale:
        scale = torch.amax(Wf.abs()).clamp_min(1e-12) / qmax
        scale = scale.expand(Wf.shape[0])
    else:
        scale = torch.amax(Wf.abs(), dim=1, keepdim=True).clamp_min(1e-12) / qmax
        scale = scale.squeeze(1)
    q = torch.clamp((Wf / scale.view(-1, 1)).round(), min=-qmax, max=qmax).to(qdtype)
    return q, scale.contiguous()

@torch.no_grad()
def dequantize_per_row_symmetric(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Rebuild float32 values ``q[i, j] * scale[i]`` from per-row integer codes."""
    row_scale = scale.reshape(-1, 1)
    return q.to(torch.float32) * row_scale

@torch.no_grad()
def caldera_q_lr_reconstruct(W: torch.Tensor, rank_k: int, q_bits: int, lr_bits: int, fold: bool, norm_fix: bool, args_for_q: dict):
    """Approximate *W* as ``Q(W) + A @ B`` and return the folded reconstruction.

    ``Q(W)`` is the per-row ``q_bits`` quantization of *W*. ``A @ B`` is the
    best rank-``rank_k`` SVD approximation of the residual ``W - Q(W)``, with
    ``A = U_k diag(S_k)`` and ``B = V_k^T``. When ``lr_bits < 16``, A and B are
    themselves per-row quantized to ``lr_bits``. Non-positive ranks skip the
    correction entirely.

    Args:
        W: 2-D weight.
        rank_k: requested correction rank (capped by the matrix dimensions).
        q_bits / lr_bits: bit widths for the base weight and the factors.
        fold: when False, nothing is reconstructed and *W* itself is returned.
        norm_fix: rescale the reconstruction to the Frobenius norm of *W*.
        args_for_q: extra keyword arguments for ``quantize_per_row_symmetric``.

    Returns:
        The reconstruction cast to ``W.dtype``, or *W* unchanged if not *fold*.
    """
    if not fold:
        # The reconstruction would be discarded; skip the quantization and SVD.
        return W
    quant_args = dict(args_for_q)
    q_main, scale_main = quantize_per_row_symmetric(W, bits=q_bits, **quant_args)
    Wq = dequantize_per_row_symmetric(q_main, scale_main)
    W_hat = Wq
    R = (W.float() - Wq).contiguous()
    rmax = min(rank_k, R.shape[0], R.shape[1])
    if rmax > 0:
        U, S, Vt = torch.linalg.svd(R, full_matrices=False)
        # Broadcasting is equivalent to U_k @ diag(S_k) without the k x k matrix.
        A = U[:, :rmax] * S[:rmax]
        B = Vt[:rmax, :]
        if lr_bits >= 16:
            LR = A @ B
        else:
            A_q, A_s = quantize_per_row_symmetric(A, bits=lr_bits, **quant_args)
            B_q, B_s = quantize_per_row_symmetric(B, bits=lr_bits, **quant_args)
            LR = dequantize_per_row_symmetric(A_q, A_s) @ dequantize_per_row_symmetric(B_q, B_s)
        W_hat = Wq + LR
    if norm_fix:
        orig_norm = W.float().norm().clamp_min(1e-12)
        reco_norm = W_hat.float().norm().clamp_min(1e-12)
        W_hat = W_hat * (orig_norm / reco_norm)
    return W_hat.to(W.dtype)

@torch.no_grad()
def singular_scores(W, G, H_diag, alpha, beta, gamma):
    """Return the SVD of *W* together with an importance score per singular component.

    The score of component ``i`` is the sum of the active terms:

    * ``alpha * minmax(|u_i^T G v_i|)`` when *G* is given and ``alpha > 0``;
    * ``beta * minmax(sum_j U[j, i]**2 * H_diag[j])`` when *H_diag* is given and
      ``beta > 0`` (so *H_diag* must have ``W.shape[0]`` entries);
    * ``gamma * S_i / sum(S)`` when ``gamma > 0``.

    If no term is active the score falls back to ``S / sum(S)``.

    Returns:
        (U, S, Vt, score, r), where r is the number of singular values. For
        ``r == 0`` the score is an empty tensor on ``W.device``.
    """
    U, S, Vt = torch.linalg.svd(W.float(), full_matrices=False)
    r = S.numel()
    if r == 0:
        return U, S, Vt, torch.empty(0, device=W.device), 0
    s_share = S / S.sum().clamp_min(1e-12)
    parts = []
    if G is not None and alpha > 0:
        # diag(U^T G V) without forming the r x r product: entry i is u_i^T (G v_i).
        proj = (U * (G.float() @ Vt.T)).sum(dim=0)
        parts.append(alpha * _minmax_norm(torch.abs(proj)))
    if H_diag is not None and beta > 0:
        parts.append(beta * _minmax_norm(U.pow(2).T @ H_diag.float()))
    if gamma > 0:
        parts.append(gamma * s_share)
    if not parts:
        return U, S, Vt, s_share, r
    score = parts[0]
    for p in parts[1:]:
        score = score + p
    return U, S, Vt, score, r

def lr_marginal_cost(out_dim, in_dim, lr_bits, scale_bits, first=False):
    """Bits needed to add one rank-1 component to a block's low-rank factors.

    The component stores ``out_dim + in_dim`` values at ``lr_bits`` each.
    Below 16 bits it also pays one ``scale_bits`` scale, plus ``out_dim`` more
    scales when *first* is set (the block's first component).
    """
    value_bits = lr_bits * (out_dim + in_dim)
    if lr_bits >= 16:
        return value_bits
    n_scales = 1 + (out_dim if first else 0)
    return value_bits + n_scales * scale_bits

def global_allocate_ks(block_infos, budget_avg_bits, lr_bits, scale_bits, allow_zero_rank, min_k_map, rank_cap_map):
    """Greedily choose a low-rank correction rank ``k`` for every block under a bit budget.

    Each entry of *block_infos* is a dict read for:
      ``"shape"`` (out_dim, in_dim);
      ``"tag"``;
      ``"S"`` (singular values; its length bounds the rank);
      ``"score"`` and ``"idx_order"``, where ``score[idx_order][k]`` is taken
      as the gain of adding the block's (k+1)-th component;
      optionally ``"q_bits"`` (default 4).

    The total budget is ``budget_avg_bits * total_params``. The quantized base
    weights are charged first: ``out*in*q_bits`` bits plus one ``scale_bits``
    scale per row. Whatever remains pays for low-rank components, each priced
    by ``lr_marginal_cost``.

    1. Start every block at its minimum rank: ``min_k_map[tag]``, floored at 1
       unless *allow_zero_rank*.
    2. If those starting ranks already exceed the remainder, drop kept
       components with the lowest gain per bit first until they fit.
    3. Repeatedly add the component with the highest gain per bit that still
       fits, up to ``min(len(S), rank_cap_map[tag])`` per block.

    Returns:
        (ks, used_bits, baseline_bits, total_params): per-block ranks in
        *block_infos* order, bits spent on low-rank components, bits of the
        base quantization, and the total weight count.

    Raises:
        ValueError: if step 2 cannot bring the starting ranks within budget.
    """
    # Total weight count and the overall bit budget it implies.
    total_params = sum(int(b["shape"][0]) * int(b["shape"][1]) for b in block_infos)
    total_budget_bits = budget_avg_bits * total_params
    # Bits of the quantized base weights: the codes plus one scale per output row.
    baseline_bits = 0
    for b in block_infos:
        qb = int(b.get("q_bits", 4))
        out_dim, in_dim = b["shape"]
        baseline_bits += out_dim * in_dim * qb + out_dim * scale_bits
    # Budget left for low-rank components (zero if the base alone overshoots).
    remain_bits = max(0, total_budget_bits - baseline_bits)
    def k0_for(b):
        # Starting rank: the tag's configured minimum, floored at 1 unless zero rank is allowed.
        tag = b["tag"]
        kmin = int(min_k_map.get(tag, 0 if allow_zero_rank else 1))
        return max(kmin, 0 if allow_zero_rank else 1)
    ks = [k0_for(b) for b in block_infos]
    def inc_cost(bi, first):
        # Bits for one more component of block `bi`; `first` adds the one-time extra scales.
        out_dim, in_dim = block_infos[bi]["shape"]
        return lr_marginal_cost(out_dim, in_dim, lr_bits, scale_bits, first=first)
    def dec_cost(bi, k):
        # Bits freed by dropping the k-th component; dropping the only one also frees the one-time cost.
        return inc_cost(bi, True) if k == 1 else inc_cost(bi, False)
    # Cost of the starting ranks: one "first" component plus (k - 1) regular ones per non-empty block.
    used_bits = sum(inc_cost(bi, True) + (ks[bi] - 1) * inc_cost(bi, False) for bi, b in enumerate(block_infos) if ks[bi] > 0)
    if used_bits > remain_bits:
        # Recycling: the starting ranks overshoot, so pop the kept component with
        # the lowest gain per bit (heapq is a min-heap) until the total fits.
        # Each block has at most one entry: its currently last kept component.
        # NOTE(review): unlike k0_for, this lower bound is not floored at 1 when
        # allow_zero_rank is False, so an explicit min_k of 0 lets a block drop
        # to rank 0 here - confirm this is intended.
        shrink_heap = []
        for bi, b in enumerate(block_infos):
            k_now = ks[bi]
            k_min = int(min_k_map.get(b["tag"], 0 if allow_zero_rank else 1))
            if k_now > k_min:
                score_vec = b["score"][b["idx_order"]]
                last_gain = float(score_vec[k_now - 1])
                last_cost = dec_cost(bi, k_now)
                heapq.heappush(shrink_heap, (last_gain / max(last_cost, 1e-9), bi))
        while used_bits > remain_bits and shrink_heap:
            _, bi = heapq.heappop(shrink_heap)
            b = block_infos[bi]
            k_now = ks[bi]
            k_min = int(min_k_map.get(b["tag"], 0 if allow_zero_rank else 1))
            if k_now <= k_min:
                continue
            saved = dec_cost(bi, k_now)
            ks[bi] -= 1
            used_bits -= saved
            if ks[bi] > k_min:
                # Re-queue the block with its new last component.
                score_vec = b["score"][b["idx_order"]]
                last_gain = float(score_vec[ks[bi] - 1])
                last_cost = dec_cost(bi, ks[bi])
                heapq.heappush(shrink_heap, (last_gain / max(last_cost, 1e-9), bi))
        if used_bits > remain_bits:
            raise ValueError("Budget too tight after recycling")
    # Growth: add the component with the highest gain per bit. Ratios are stored
    # negated because heapq is a min-heap. Each block has at most one pending
    # entry, for its next ordered component (index ks[bi]).
    heap = []
    # Whether a block has already been charged the one-time first-component cost.
    paid_a_scale = [k > 0 for k in ks]
    for bi, b in enumerate(block_infos):
        max_rank = min(b["S"].numel(), int(rank_cap_map.get(b["tag"], b["S"].numel())))
        # NOTE(review): k0_for does not clamp to rank_cap_map, so a block whose
        # starting rank already reaches the cap gets no entry and keeps that rank.
        if ks[bi] < max_rank:
            gain = float(b["score"][b["idx_order"]][ks[bi]])
            cost = inc_cost(bi, first=(ks[bi] == 0 and not paid_a_scale[bi]))
            heapq.heappush(heap, (-(gain / max(cost, 1e-9)), bi))
    while heap:
        _, bi = heapq.heappop(heap)
        b = block_infos[bi]
        max_rank = min(b["S"].numel(), int(rank_cap_map.get(b["tag"], b["S"].numel())))
        if ks[bi] >= max_rank:
            continue
        cost = inc_cost(bi, first=(ks[bi] == 0 and not paid_a_scale[bi]))
        if used_bits + cost > remain_bits:
            # Not re-queued: this block's next cost does not shrink and used_bits
            # only grows, so it can never fit later.
            continue
        if ks[bi] == 0:
            paid_a_scale[bi] = True
        ks[bi] += 1
        used_bits += cost
        if ks[bi] < max_rank:
            next_gain = float(b["score"][b["idx_order"]][ks[bi]])
            next_cost = inc_cost(bi, first=False)
            heapq.heappush(heap, (-(next_gain / max(next_cost, 1e-9)), bi))
    return ks, used_bits, baseline_bits, total_params

def _score_io_block(name, tag, weight, args, q_bits_map):
    """Score an embedding / output-head matrix by its spectrum alone.

    No gradient, Hessian or preconditioning is used for these blocks: the score
    comes from ``singular_scores`` with only the spectrum term (gamma = 1).
    Returns a ``block_infos`` entry.
    """
    W = weight.detach().clone().to(args.score_device)
    _, S, _, score, _ = singular_scores(W, None, None, 0, 0, 1.0)
    info = {
        "name": name,
        "tag": tag,
        "shape": W.shape,
        "S": S.cpu(),
        "score": score.cpu(),
        "idx_order": torch.argsort(score, descending=True).cpu(),
        "q_bits": int(q_bits_map.get(tag, args.q_bits)),
    }
    del W, S, score
    return info


def _reconstruct_weight(weight, k, q_bits_used, seed_key, args, quant_args):
    """Return a CPU tensor of ``weight.dtype``: quantized *weight* plus a rank-*k* correction.

    For 2-bit blocks the weight is rotated by ``apply_precond`` before
    quantization and rotated back afterwards. The seed is
    ``(_stable_int(seed_key) ^ args.rht_seed) & 0xFFFFFFFF``.
    """
    W_on_device = weight.detach().to(args.reconstruct_device)
    if q_bits_used == 2:
        seed = (_stable_int(seed_key) ^ args.rht_seed) & 0xFFFFFFFF
        W_pre, inv = apply_precond(W_on_device.clone(), args.rht_row_block, args.rht_col_block, seed)
    else:
        W_pre, inv = W_on_device, (lambda x: x)
    W_pre_new = caldera_q_lr_reconstruct(
        W_pre, rank_k=k, q_bits=q_bits_used, lr_bits=args.lr_bits,
        fold=True, norm_fix=args.norm_fix, args_for_q=quant_args,
    )
    return inv(W_pre_new).to(weight.dtype).cpu()


def main():
    """Command-line entry point.

    1. Score the singular components of every selected weight block, using
       gradient, Hessian-diagonal and spectrum terms.
    2. Allocate low-rank correction ranks under the average-bit budget.
    3. Replace each selected weight with its quantized + low-rank
       reconstruction. This includes ``embed_tokens`` / ``lm_head`` when they
       were selected and budgeted.
    4. Save the model, the tokenizer and a ``quant_meta.json`` summary.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model_dir", required=True)
    ap.add_argument("--grad_dir", required=True)
    ap.add_argument("--hess_dir", default=None)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--budget_avg_bits", type=float, required=True)
    ap.add_argument("--q_bits", type=int, default=4, choices=[2,4,8])
    ap.add_argument("--q_bits_map", type=str, default="")
    ap.add_argument("--lr_bits", type=int, default=4, choices=[4,8,16])
    ap.add_argument("--scale_bits", type=int, default=16, choices=[8,16,32])
    ap.add_argument("--allow_zero_rank", action="store_true")
    ap.add_argument("--min_k_map", type=str, default="")
    ap.add_argument("--rank_cap_map", type=str, default="")
    ap.add_argument("--alpha", type=float, default=1.0)
    ap.add_argument("--beta", type=float, default=0.0)
    ap.add_argument("--gamma", type=float, default=1.0)
    ap.add_argument("--blocks", type=str, default="all")
    ap.add_argument("--norm_fix", dest="norm_fix", action="store_true")
    ap.add_argument("--no-norm_fix", dest="norm_fix", action="store_false")
    ap.set_defaults(norm_fix=True)
    ap.add_argument("--score_device", choices=["cpu","cuda"], default="cpu")
    ap.add_argument("--reconstruct_device", choices=["cpu","cuda"], default="cuda")
    ap.add_argument("--q2_mode", choices=["zero","sym","pm3"], default="zero")
    ap.add_argument("--q2_lsq_iters", type=int, default=8)
    ap.add_argument("--rht_row_block", type=int, default=256)
    ap.add_argument("--rht_col_block", type=int, default=256)
    ap.add_argument("--rht_seed", type=int, default=1234)
    args = ap.parse_args()
    min_k_map = parse_kv_map_int(args.min_k_map)
    rank_cap_map = parse_kv_map_int(args.rank_cap_map)
    q_bits_map = parse_kv_map_int(args.q_bits_map)
    model = AutoModelForCausalLM.from_pretrained(args.base_model_dir, torch_dtype="auto", low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_dir)
    blocks = parse_blocks(args.blocks)

    def tag2mod(layer):
        # Projection modules of one decoder layer keyed by block tag; "gate" is
        # None when the MLP has no gate_proj.
        return {
            "q": layer.self_attn.q_proj,
            "k": layer.self_attn.k_proj,
            "v": layer.self_attn.v_proj,
            "o": layer.self_attn.o_proj,
            "up": layer.mlp.up_proj,
            "down": layer.mlp.down_proj,
            "gate": getattr(layer.mlp, "gate_proj", None),
        }

    # ---- 1. scoring ----
    block_infos = []
    for i, layer in enumerate(tqdm(model.model.layers, desc="Scoring Layers")):
        for tag, mod in tag2mod(layer).items():
            if mod is None or tag not in blocks:
                continue
            q_bits_used = int(q_bits_map.get(tag, args.q_bits))
            W = mod.weight.detach().clone().to(args.score_device)
            G = load_data(os.path.join(args.grad_dir, f"{i}_{tag}.pt"), device=args.score_device)
            if G is not None and G.shape != W.shape:
                G = None
            H_diag = load_diag(args.hess_dir, i, tag, W.shape[0], W.shape[1])
            if H_diag is not None:
                H_diag = H_diag.to(args.score_device)
            if q_bits_used == 2:
                # Score in the same rotated basis that reconstruction quantizes in.
                seed = (_stable_int(f"{i}_{tag}") ^ args.rht_seed) & 0xFFFFFFFF
                W_pre, _ = apply_precond(W, args.rht_row_block, args.rht_col_block, seed)
                G_pre = None if G is None else apply_precond(G, args.rht_row_block, args.rht_col_block, seed)[0]
            else:
                W_pre, G_pre = W, G
            _, S, _, score, _ = singular_scores(W_pre, G_pre, H_diag, args.alpha, args.beta, args.gamma)
            block_infos.append({"name": f"L{i:02d}_{tag}", "tag": tag, "shape": W.shape, "S": S.cpu(), "score": score.cpu(), "idx_order": torch.argsort(score, descending=True).cpu(), "q_bits": q_bits_used})
            del W, G, H_diag, W_pre, G_pre, S, score
            if args.score_device == "cuda":
                torch.cuda.empty_cache()
    if "embed" in blocks:
        block_infos.append(_score_io_block("EMBED", "embed", model.model.embed_tokens.weight, args, q_bits_map))
    head_weight = getattr(getattr(model, "lm_head", None), "weight", None)
    if "lm_head" in blocks and head_weight is not None:
        block_infos.append(_score_io_block("LM_HEAD", "lm_head", head_weight, args, q_bits_map))
    if args.score_device == "cuda":
        torch.cuda.empty_cache()
    if not block_infos:
        # Exits with a usage message instead of dividing by zero parameters below.
        ap.error(f"no weight blocks matched --blocks={args.blocks!r}")

    # ---- 2. rank allocation ----
    final_ks, used_lr_bits, baseline_bits, total_params = global_allocate_ks(block_infos, args.budget_avg_bits, args.lr_bits, args.scale_bits, args.allow_zero_rank, min_k_map, rank_cap_map)
    # NOTE: a tied embed/lm_head pair is budgeted (and counted here) twice even
    # though only one shared matrix is reconstructed below.
    realized_avg_bits = (baseline_bits + used_lr_bits) / total_params

    # ---- 3. reconstruction ----
    quant_args = {"q2_mode": args.q2_mode, "q2_lsq_iters": args.q2_lsq_iters}
    name_to_idx = {info["name"]: idx for idx, info in enumerate(block_infos)}

    def reconstruct_into(weight, name, seed_key):
        # Overwrite `weight` in place with the reconstruction of block `name`
        # (no-op if that block was not scored).
        info_idx = name_to_idx.get(name)
        if info_idx is None:
            return
        q_bits_used = int(block_infos[info_idx].get("q_bits", args.q_bits))
        W_new = _reconstruct_weight(weight, final_ks[info_idx], q_bits_used, seed_key, args, quant_args)
        weight.data.copy_(W_new)
        del W_new
        if args.reconstruct_device == "cuda":
            torch.cuda.empty_cache()

    for i, layer in enumerate(tqdm(model.model.layers, desc="Reconstructing")):
        for tag, mod in tag2mod(layer).items():
            if mod is None or tag not in blocks:
                continue
            reconstruct_into(mod.weight, f"L{i:02d}_{tag}", f"{i}_{tag}")
    # The embedding and the output head were scored and budgeted above, so they
    # must be rebuilt as well. A tied head shares storage with the embedding and
    # is reconstructed only once.
    embed_weight = model.model.embed_tokens.weight if "embed" in blocks else None
    if embed_weight is not None:
        reconstruct_into(embed_weight, "EMBED", "embed")
    if "lm_head" in blocks and head_weight is not None:
        if embed_weight is not None and head_weight.data_ptr() == embed_weight.data_ptr():
            print("lm_head shares its weight with embed_tokens; it was reconstructed once as EMBED")
        else:
            reconstruct_into(head_weight, "LM_HEAD", "lm_head")

    # ---- 4. save ----
    os.makedirs(args.out_dir, exist_ok=True)
    if hasattr(model, "tie_weights"):
        try:
            model.tie_weights()
        except Exception as err:  # best effort: saving should proceed even if re-tying fails
            print(f"warning: model.tie_weights() failed ({err!r}); saving weights as they are")
    model.save_pretrained(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)
    meta = {"budget_avg_bits": args.budget_avg_bits, "realized_avg_bits": float(realized_avg_bits), "q_bits": args.q_bits, "lr_bits": args.lr_bits, "scale_bits": args.scale_bits, "total_params": int(total_params), "baseline_bits": float(baseline_bits), "used_lr_bits": float(used_lr_bits), "method": "teacher_guided_rank_selection", "alpha": args.alpha, "beta": args.beta, "gamma": args.gamma, "ranks": [{"name": info["name"], "tag": info["tag"], "out": int(info["shape"][0]), "in": int(info["shape"][1]), "k": int(k), "q_bits": int(info.get("q_bits", args.q_bits))} for info, k in zip(block_infos, final_ks)]}
    with open(os.path.join(args.out_dir, "quant_meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
